Skip to content

Conversation

@wdykas
Copy link
Contributor

@wdykas wdykas commented Nov 6, 2025

Description

I want to be able to control num splits in FA3. This exposes this argument for non-context-parallel cases.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This PR exposes the num_splits parameter for FlashAttention v2 and v3 backends, allowing users to control memory optimization during attention computation.

Key Changes:

  • Added optional num_splits parameter to DotProductAttention.forward() method
  • Passes num_splits to both FlashAttention v2 and v3 backend implementations when provided
  • Parameter is conditionally added to kwargs only when not None

Areas for Improvement:

  • Missing parameter documentation in the docstring
  • No version compatibility check for flash-attn (unlike other optional parameters like window_size and deterministic)
  • No tests demonstrating the new functionality

Confidence Score: 4/5

  • This PR is safe to merge with minor documentation improvements recommended
  • The implementation correctly follows the existing pattern for optional parameters in FlashAttention backends. The changes are minimal and well-scoped. However, the score is not 5 due to: (1) missing parameter documentation, (2) lack of version compatibility checks that other optional parameters have, and (3) no accompanying tests. These are quality-of-life improvements rather than critical issues.
  • No files require special attention - the implementation is straightforward and follows existing patterns

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/pytorch/attention/dot_product_attention/backends.py 4/5 Added num_splits parameter to FlashAttention forward method and passes it to both FA v2 and FA v3 backends when provided
transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py 4/5 Added num_splits parameter to DotProductAttention forward signature and forwards it to FlashAttention backend

Sequence Diagram

sequenceDiagram
    participant User
    participant DotProductAttention
    participant FlashAttention
    participant FlashAttn_v2
    participant FlashAttn_v3

    User->>DotProductAttention: forward(query, key, value, num_splits=...)
    Note over DotProductAttention: New optional parameter
    DotProductAttention->>FlashAttention: forward(..., num_splits=num_splits)
    Note over FlashAttention: Check backend version
    alt Flash Attention v2
        FlashAttention->>FlashAttention: if num_splits is not None
        FlashAttention->>FlashAttention: fa_optional_forward_kwargs["num_splits"] = num_splits
        FlashAttention->>FlashAttn_v2: flash_attn_func(..., **fa_optional_forward_kwargs)
        FlashAttn_v2-->>FlashAttention: output
    else Flash Attention v3
        FlashAttention->>FlashAttention: if num_splits is not None
        FlashAttention->>FlashAttention: fa_3_optional_forward_kwargs["num_splits"] = num_splits
        FlashAttention->>FlashAttn_v3: flash_attn_func(..., **fa_3_optional_forward_kwargs)
        FlashAttn_v3-->>FlashAttention: output
    end
    FlashAttention-->>DotProductAttention: output
    DotProductAttention-->>User: output
Loading

2 files reviewed, 3 comments

Edit Code Review Agent Settings | Greptile

inference_params: Optional[InferenceParams] = None,
pad_between_seqs: Optional[bool] = None,
fp8_output: Optional[bool] = False,
num_splits: Optional[int] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: Missing documentation for the num_splits parameter in the docstring. Add a parameter description explaining what this controls and when it should be used (e.g., "Controls the number of splits for FlashAttention computation. Used for memory optimization.")

Comment on lines +928 to +941
if num_splits is not None:
# Forward optional split control to flash-attn if available
fa_optional_forward_kwargs["num_splits"] = num_splits
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: Verify that the flash-attn version supports num_splits parameter. Unlike other optional parameters (e.g., window_size has fa_utils.v2_3_plus check, deterministic has fa_utils.v2_4_1_plus check), this parameter is added without a version guard.

Comment on lines +965 to +978
if num_splits is not None:
# Forward optional split control to flash-attn v3 if supported
fa_3_optional_forward_kwargs["num_splits"] = num_splits
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: Same as FA v2: verify flash-attn v3 supports num_splits to avoid potential TypeError

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Nov 10, 2025

Greptile Overview

Greptile Summary

This PR adds support for controlling the num_splits parameter in FlashAttention-3 for non-context-parallel cases. The implementation adds the parameter to the DotProductAttention.forward() method and properly validates that it's only used with FA3 backends.

Key Changes:

  • Added num_splits parameter to DotProductAttention.forward() with comprehensive docstring
  • Implemented dual-layer validation (at both DotProductAttention and FlashAttention levels) to ensure num_splits is only used with FA3
  • Forwards num_splits to FA3 backend functions (flash_attn_func_v3, flash_attn_varlen_func_v3, flash_attn_with_kvcache_v3)
  • Added test case test_dpa_num_splits that verifies the feature works with FA3

Issues Found:

  • Dead code at backends.py:939-941 - the check if num_splits is not None: inside the FA2 path can never be true due to earlier validation
  • Previous comments about missing version guards appear to be false positives - the validation at line 848 correctly prevents num_splits from being passed to FA2

Confidence Score: 4/5

  • Safe to merge with minor code cleanup suggested
  • The implementation is functionally correct with proper validation ensuring num_splits is only used with FA3. The dead code at line 939-941 doesn't affect correctness but should be removed for clarity. Tests adequately cover the new functionality.
  • backends.py:939-941 contains unreachable dead code that should be removed for code clarity

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py 4/5 Added num_splits parameter with proper validation to ensure it's only used with FA3
transformer_engine/pytorch/attention/dot_product_attention/backends.py 3/5 Forwards num_splits to FA2/FA3 backends with validation, but contains unreachable dead code in FA2 path
tests/pytorch/attention/test_attention.py 5/5 Added test for num_splits with proper FA3 availability check and backend filtering

Sequence Diagram

sequenceDiagram
    participant User
    participant DotProductAttention
    participant FlashAttention
    participant FA3/FA2

    User->>DotProductAttention: forward(query, key, value, num_splits=2)
    DotProductAttention->>DotProductAttention: Validate num_splits requires FA3
    alt FA3 not selected and num_splits set
        DotProductAttention-->>User: ValueError (FA3 required)
    end
    DotProductAttention->>FlashAttention: forward(..., num_splits=2)
    FlashAttention->>FlashAttention: Check use_flash_attn_3 flag
    alt num_splits set and not FA3
        FlashAttention-->>DotProductAttention: ValueError (FA3 required)
    end
    alt FA3 enabled
        FlashAttention->>FA3: flash_attn_*_v3(..., num_splits=2)
        FA3-->>FlashAttention: output
    else FA2 path (num_splits must be None)
        FlashAttention->>FA2: flash_attn_func(...)
        FA2-->>FlashAttention: output
    end
    FlashAttention-->>DotProductAttention: output
    DotProductAttention-->>User: output
Loading

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

Comment on lines +938 to +941
if num_splits is not None:
# Forward optional split control to flash-attn if available
fa_optional_forward_kwargs["num_splits"] = num_splits
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: Unlike other optional parameters (window_size has fa_utils.v2_3_plus, deterministic has fa_utils.v2_4_1_plus), num_splits is added to FA2 kwargs without version checking. If flash-attn v2 doesn't support this parameter, this will cause a TypeError at runtime.

Suggested change
if num_splits is not None:
# Forward optional split control to flash-attn if available
fa_optional_forward_kwargs["num_splits"] = num_splits
if num_splits is not None:
# Only add num_splits if flash-attn supports it (check version if needed)
# TODO: Add version check once minimum flash-attn version with num_splits is determined
fa_optional_forward_kwargs["num_splits"] = num_splits

@pytest.mark.parametrize("qkv_layout", ["sbhd_sbhd_sbhd", "thd_thd_thd"])
def test_dpa_max_logit(dtype, model_configs, model, qkv_layout):
@pytest.mark.parametrize("num_splits", [None, 2])
def test_dpa_max_logit(dtype, model_configs, model, qkv_layout, num_splits):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you make the num_splits a separate test, instead of piggybacking on the max_logit test :) You can still call test_dot_product_attention in it the same way you do here. Thanks!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done I think

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3 files reviewed, no comments

Edit Code Review Agent Settings | Greptile

@cyanguwa
Copy link
Collaborator

Could you please follow the instructions here to fix the DCO? Thanks!

@cyanguwa
Copy link
Collaborator

/te-ci pytorch L0

Oleg-Goncharov and others added 21 commits November 12, 2025 08:32
Deleted unused header

Signed-off-by: Oleg Goncharov <[email protected]>
Signed-off-by: Peter Dykas <[email protected]>
…#2321)

* L1 rework

Signed-off-by: Phuong Nguyen <[email protected]>

* comment out test_multi_process_grouped_gemm for now

Signed-off-by: Phuong Nguyen <[email protected]>

* rm e5m2 from test norm + MXFP8

Signed-off-by: Phuong Nguyen <[email protected]>

---------

Signed-off-by: Phuong Nguyen <[email protected]>
Signed-off-by: Peter Dykas <[email protected]>
Signed-off-by: Peter Dykas <[email protected]>
Signed-off-by: Peter Dykas <[email protected]>
Signed-off-by: Peter Dykas <[email protected]>
* code drop

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* fix:

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

---------

Signed-off-by: Pawel Gadzinski <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Peter Dykas <[email protected]>
* code drop

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

---------

Signed-off-by: Pawel Gadzinski <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Peter Dykas <[email protected]>
* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

---------

Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Peter Dykas <[email protected]>
* Try to use pre-downloaded dataset artifacts first

Signed-off-by: Jeremy Berchtold <[email protected]>

* Set HF_HUB_OFFLINE to disable any network calls to HF when the
pre-downloaded dataset is available

Signed-off-by: Jeremy Berchtold <[email protected]>

---------

Signed-off-by: Jeremy Berchtold <[email protected]>
Signed-off-by: Peter Dykas <[email protected]>
* Make cast_master_weights_to_fp8 compatible with older MCore version

Signed-off-by: kunlunl <[email protected]>

* Rename keep_columnwise to manual_post_all_gather_processing & Optimize unit test

Signed-off-by: kunlunl <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove redundant _test_mini_optimizer()

Signed-off-by: kunlunl <[email protected]>

---------

Signed-off-by: kunlunl <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: Peter Dykas <[email protected]>
…VIDIA#2348)

* Add test to check jaxpr that amax is reused for nvfp4 recipe

Signed-off-by: Jeremy Berchtold <[email protected]>

* Move test to test_helper.py and rename file

Signed-off-by: Jeremy Berchtold <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Jeremy Berchtold <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Peter Dykas <[email protected]>
* Fix cuDNN backend selection for more case. Add CG as a option as well

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* fix logic

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fix cuDNN checks

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Add more checks

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fix cuddn version

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fix error message

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Add check for window size

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Signed-off-by: Peter Dykas <[email protected]>
* Default to fused attention in JAX DPA

Signed-off-by: Kshitij Lakhani <[email protected]>

* Consolidate documentation for DPA in JAX

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Kshitij Lakhani <[email protected]>

* Correctly update the documentation for defaults in JAX DPA

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Kshitij Lakhani <[email protected]>

---------

Signed-off-by: Kshitij Lakhani <[email protected]>
Signed-off-by: Kshitij Lakhani <[email protected]>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Peter Dykas <[email protected]>
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Signed-off-by: Peter Dykas <[email protected]>
…kernel type. (NVIDIA#2287)

* code drop

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <[email protected]>

* depracted compile time warning + \warning -> \deprecated

Signed-off-by: Pawel Gadzinski <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Charlene Yang <[email protected]>
Signed-off-by: Peter Dykas <[email protected]>
* move triton to common and change paths

Signed-off-by: tdophung <[email protected]>

* Formatting

Signed-off-by: tdophung <[email protected]>

---------

Signed-off-by: tdophung <[email protected]>
Signed-off-by: Peter Dykas <[email protected]>
* Changing default activations in MLP, TransformerLayer, dropout rate after FC1 to 0, and return_layernorm_output to False

Signed-off-by: tdophung <[email protected]>

* Fixing the failing tests by hard coding  arguments to the previous values instead of relying on newer default values

Signed-off-by: tdophung <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: tdophung <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Peter Dykas <[email protected]>
Signed-off-by: Peter Dykas <[email protected]>
@wdykas wdykas force-pushed the num-splits-attention branch from 2c90ec0 to 52a5f37 Compare November 12, 2025 16:32
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

57 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

Comment on lines +848 to +856
if num_splits is not None and not use_flash_attn_3:
if not fa_utils.v3_is_installed:
raise ValueError(
"num_splits is only supported with FlashAttention-3, which is not installed. "
)
raise ValueError(
"num_splits is only supported with FlashAttention-3. "
"Please adjust configuration to enable FA3 for these inputs."
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: The validation logic here correctly prevents num_splits from being used with non-FA3 backends. However, there's a potential issue with code maintainability: later in the code (line 939), num_splits is added to fa_optional_forward_kwargs without a version check, unlike other optional parameters. While the current validation prevents reaching that code with num_splits set and FA2, this creates fragile coupling between distant code sections. Consider either: (1) adding a version guard at line 939 similar to window_size and deterministic, or (2) adding an assertion that num_splits is None in the FA2 branch.

@wdykas
Copy link
Contributor Author

wdykas commented Nov 12, 2025

Could you please follow the instructions here to fix the DCO? Thanks!

I think this is done?

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

Comment on lines +939 to +941
if num_splits is not None:
# Forward optional split control to flash-attn if available
fa_optional_forward_kwargs["num_splits"] = num_splits
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: This is unreachable dead code. The validation at line 848 already ensures num_splits is None when use_flash_attn_3 is False, so this condition can never be true inside the if not use_flash_attn_3: block. Consider removing this check for code clarity.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.